Source code for hysop.operator.base.spatial_filtering

# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
@file spatial_filtering.py
Base classes for spatial filtering operators (restriction and interpolation
between grids of different resolutions).
"""
import numpy as np
from hysop.constants import Implementation
from hysop.methods import Remesh
from hysop.numerics.remesh.remesh import RemeshKernel
from hysop.tools.io_utils import IOParams
from hysop.tools.htypes import check_instance, to_list, first_not_None, InstanceOf
from hysop.tools.numpywrappers import npw
from hysop.tools.decorators import debug
from hysop.tools.numerics import find_common_dtype
from hysop.tools.spectral_utils import SpectralTransformUtils
from hysop.tools.method_utils import PolynomialInterpolationMethod
from hysop.fields.continuous_field import Field, ScalarField
from hysop.parameters.scalar_parameter import ScalarParameter
from hysop.topology.cartesian_descriptor import CartesianTopologyDescriptors
from hysop.core.graph.node_generator import ComputationalGraphNodeGenerator
from hysop.core.graph.computational_node_frontend import ComputationalGraphNodeFrontend
from hysop.core.memory.memory_request import MemoryRequest
from hysop.operator.base.spectral_operator import SpectralOperatorBase


class SpatialFilterBase:
    """
    Common base for spatial filters moving a scalar field from one grid to
    another (restriction: fine -> coarse, interpolation: coarse -> fine).

    Input and output fields must share dimension, left/right boundary
    conditions and periodicity. Subclasses are responsible for setting
    ``self.iratio`` in their ``get_field_requirements``.
    """

    def __new__(cls, input_field, output_field, input_topo, output_topo, **kwds):
        # The field/topology arguments are consumed here; the next class in the
        # MRO only receives placeholder field dictionaries at allocation time.
        return super().__new__(cls, input_fields=None, output_fields=None, **kwds)

    def __init__(self, input_field, output_field, input_topo, output_topo, **kwds):
        for field in (input_field, output_field):
            check_instance(field, ScalarField)
        for topo in (input_topo, output_topo):
            check_instance(topo, CartesianTopologyDescriptors)

        super().__init__(
            input_fields={input_field: input_topo},
            output_fields={output_field: output_topo},
            **kwds,
        )

        src, dst = input_field, output_field
        assert src.dim == dst.dim
        assert (src.lboundaries == dst.lboundaries).all()
        assert (src.rboundaries == dst.rboundaries).all()
        assert (src.periodicity == dst.periodicity).all()

        self.Fin, self.Fout = src, dst
        self.dim = src.dim
        self.dtype = find_common_dtype(src.dtype, dst.dtype)
        # Integer grid ratio, filled in by subclasses in get_field_requirements().
        self.iratio = None
        # iratio transposed to the input discrete topology state, set in discretize().
        self.grid_ratio = None

    @debug
    def discretize(self):
        """Fetch the discrete fields and transpose the integer grid ratio."""
        if not self.discretized:
            super().discretize()
            discrete_in = self.get_input_discrete_field(self.Fin)
            discrete_out = self.get_output_discrete_field(self.Fout)
            ratio = discrete_in.topology_state.transposed(self.iratio)
            self.dFin, self.dFout, self.grid_ratio = discrete_in, discrete_out, ratio

    @classmethod
    def supports_multiple_field_topologies(cls):
        """Input and output fields are allowed to use different topologies."""
        return True

    @classmethod
    def supports_mpi(cls):
        """This operator family declares MPI support."""
        return True

    def get_preserved_input_fields(self):
        """Report the input field as preserved by this operator."""
        return {self.Fin}
class RestrictionFilterBase(SpatialFilterBase):
    """
    Base class for restriction filters: the output grid must be coarser than
    (or equal to) the input grid by an integer factor on every axis.
    """

    @debug
    def get_field_requirements(self):
        """
        Compute the integer coarsening ratio ``Fout_dx / Fin_dx`` per axis.

        Sets ``self.iratio`` to a tuple of Python ints.

        The ratio comes from a floating point division of space steps, so it is
        rounded to the nearest integer and checked with a relative tolerance
        rather than truncated: e.g. ``0.3 / 0.1`` evaluates to
        ``2.9999999999999996``, which truncation would turn into 2.
        """
        requirements = super().get_field_requirements()
        rtol = 1e-8  # far below any genuine non-integer ratio of grid sizes

        Fin_topo, _ = requirements.get_input_requirement(self.Fin)
        try:
            Fin_dx = Fin_topo.space_step
        except AttributeError:
            Fin_dx = Fin_topo.mesh.space_step

        Fout_topo, _ = requirements.get_output_requirement(self.Fout)
        try:
            Fout_dx = Fout_topo.space_step
        except AttributeError:
            Fout_dx = Fout_topo.mesh.space_step

        ratio = np.asarray(Fout_dx, dtype=np.float64) / np.asarray(
            Fin_dx, dtype=np.float64
        )

        # A ratio of 1 (identical grids) is accepted even with round-off.
        msg = f"Destination grid is finer than source grid: {ratio}"
        assert ((ratio >= 1.0) | np.isclose(ratio, 1.0, rtol=rtol, atol=0.0)).all(), msg

        iratio = np.rint(ratio).astype(np.int32)
        msg = f"Grid ratio is not an integer on at least one axis: {ratio}"
        assert np.allclose(ratio, iratio, rtol=rtol, atol=0.0), msg

        self.iratio = tuple(iratio.tolist())
        return requirements
class InterpolationFilterBase(SpatialFilterBase):
    """
    Base class for interpolation filters: the output grid must be finer than
    (or equal to) the input grid by an integer factor on every axis.
    """

    @debug
    def get_field_requirements(self):
        """
        Compute the integer refinement ratio ``Fin_dx / Fout_dx`` per axis.

        Sets ``self.iratio`` to a tuple of Python ints.

        The ratio comes from a floating point division of space steps, so it is
        rounded to the nearest integer and checked with a relative tolerance
        rather than truncated: e.g. ``0.3 / 0.1`` evaluates to
        ``2.9999999999999996``, which truncation would turn into 2.
        """
        requirements = super().get_field_requirements()
        rtol = 1e-8  # far below any genuine non-integer ratio of grid sizes

        Fin_topo, _ = requirements.get_input_requirement(self.Fin)
        try:
            Fin_dx = Fin_topo.space_step
        except AttributeError:
            Fin_dx = Fin_topo.mesh.space_step

        Fout_topo, _ = requirements.get_output_requirement(self.Fout)
        try:
            Fout_dx = Fout_topo.space_step
        except AttributeError:
            Fout_dx = Fout_topo.mesh.space_step

        ratio = np.asarray(Fin_dx, dtype=np.float64) / np.asarray(
            Fout_dx, dtype=np.float64
        )

        # A ratio of 1 (identical grids) is accepted even with round-off.
        msg = f"Source grid is finer than destination grid: {ratio}"
        assert ((ratio >= 1.0) | np.isclose(ratio, 1.0, rtol=rtol, atol=0.0)).all(), msg

        iratio = np.rint(ratio).astype(np.int32)
        msg = f"Grid ratio is not an integer on at least one axis: {ratio}"
        assert np.allclose(ratio, iratio, rtol=rtol, atol=0.0), msg

        self.iratio = tuple(iratio.tolist())
        return requirements
class SpectralRestrictionFilterBase(RestrictionFilterBase, SpectralOperatorBase):
    """
    Base implementation for lowpass spatial filtering: fine grid -> coarse grid
    using the spectral method.

    The input field is forward transformed on the fine grid and the output field
    is backward transformed on the coarse grid. ``setup`` precomputes
    ``self.fslices`` (source/destination slices selecting the fine-grid spectral
    coefficients that fit in the coarse spectrum) and ``self.scaling``, which
    subclasses must provide through ``_compute_scaling_coefficient``.
    """

    @debug
    def __new__(cls, plot_input_energy=None, plot_output_energy=None, **kwds):
        return super().__new__(cls, **kwds)

    @debug
    def __init__(self, plot_input_energy=None, plot_output_energy=None, **kwds):
        """
        Initialize a SpectralRestrictionFilterBase.

        Parameters
        ----------
        plot_input_energy: IOParams, optional, defaults to None
            Plot input field energy in a custom file.
        plot_output_energy: IOParams, optional, defaults to None
            Plot output field energy in a custom file.

        Notes
        -----
        IOParams filename is formatted before being used:
            {fname} is replaced with field name
            {ite} is replaced with simulation iteration id
        If None is passed, no plots are generated.
        """
        check_instance(plot_input_energy, IOParams, allow_none=True)
        check_instance(plot_output_energy, IOParams, allow_none=True)
        super().__init__(**kwds)

        Fin, Fout = self.Fin, self.Fout

        # check that boundary conditions are matching
        msg = (
            "Input field {l}boundaries {} mismatch with output field {l}boundaries {}."
        )
        assert (Fin.lboundaries == Fout.lboundaries).all(), msg.format(
            Fin.lboundaries, Fout.lboundaries, l="l"
        )
        assert (Fin.rboundaries == Fout.rboundaries).all(), msg.format(
            Fin.rboundaries, Fout.rboundaries, l="r"
        )

        # Forward transform of the fine input and backward transform of the
        # coarse output live in two transform groups with distinct memory tags.
        tg_fine = self.new_transform_group(mem_tag="FINE")
        tg_coarse = self.new_transform_group(mem_tag="COARSE")
        Ft = tg_fine.require_forward_transform(
            Fin, custom_output_buffer="auto", plot_energy=plot_input_energy
        )
        Bt = tg_coarse.require_backward_transform(
            Fout, custom_input_buffer="B0", plot_energy=plot_output_energy
        )

        self.tg_fine = tg_fine
        self.tg_coarse = tg_coarse
        self.Ft = Ft
        self.Bt = Bt

    @debug
    def discretize(self):
        """Discretize and check the fine grid is at least as large as the coarse one."""
        if self.discretized:
            return
        super().discretize()
        dFin, dFout = self.dFin, self.dFout
        # The coarse mesh is the output field, the fine mesh is the input field:
        # format the message in that order so the report names the right meshes.
        msg = (
            "Compute resolution of coarse mesh {}::{} is greater than "
            "compute resolution of fine mesh {}::{}."
        ).format(
            self.Fout.name,
            dFout.compute_resolution,
            self.Fin.name,
            dFin.compute_resolution,
        )
        assert (dFin.compute_resolution >= dFout.compute_resolution).all(), msg

    def setup(self, work):
        """Bind spectral buffers and precompute filter slices and scaling."""
        super().setup(work)
        # Output buffer of the fine forward transform / input buffer of the
        # coarse backward transform.
        self.FIN = self.Ft.output_buffer
        self.FOUT = self.Bt.input_buffer
        self.fslices = self._generate_filter_slices()
        self.scaling = self._compute_scaling_coefficient()

    def _generate_filter_slices(self):
        """
        Build slices mapping kept fine-grid spectral coefficients (in ``FIN``)
        onto coarse-grid spectral coefficients (in ``FOUT``).

        Returns
        -------
        (src_slices, dst_slices): pair of tuples of slice tuples
            ``FIN[src_slices[i]]`` and ``FOUT[dst_slices[i]]`` have the same
            shape for every i. On complex-to-complex axes both the first
            ``(n + 1) // 2`` and the last ``n // 2`` coefficients are kept
            (presumably the non-negative and negative frequencies of the usual
            FFT layout), which doubles the number of slice combinations. On any
            other axis the first ``n`` coefficients are kept.
        """
        src_slices = [[]]
        dst_slices = [[]]
        transforms = tuple(self.Ft.transforms[i] for i in self.Ft.output_axes)
        for N, n, tr in zip(self.FIN.shape, self.FOUT.shape, transforms):
            assert len(src_slices) == len(dst_slices)
            assert n <= N
            if SpectralTransformUtils.is_C2C(tr):
                left_src_slices = [s[:] for s in src_slices]
                right_src_slices = [s[:] for s in src_slices]
                lsrc = slice(0, (n + 1) // 2, 1)
                rsrc = slice(N - n // 2, N, 1)
                for lslc, rslc in zip(left_src_slices, right_src_slices):
                    lslc.append(lsrc)
                    rslc.append(rsrc)
                src_slices = left_src_slices + right_src_slices

                left_dst_slices = [s[:] for s in dst_slices]
                right_dst_slices = [s[:] for s in dst_slices]
                ldst = slice(0, (n + 1) // 2, 1)
                rdst = slice(n - n // 2, n, 1)
                for lslc, rslc in zip(left_dst_slices, right_dst_slices):
                    lslc.append(ldst)
                    rslc.append(rdst)
                dst_slices = left_dst_slices + right_dst_slices
            else:
                src = slice(0, n, 1)
                dst = slice(0, n, 1)
                for src_slc, dst_slc in zip(src_slices, dst_slices):
                    src_slc.append(src)
                    dst_slc.append(dst)
        src_slices = tuple(tuple(s) for s in src_slices)
        dst_slices = tuple(tuple(s) for s in dst_slices)
        return (src_slices, dst_slices)

    def _compute_scaling_coefficient(self):
        # scaling can depend on the fft backend so we bruteforce it
        # in every backend
        msg = "_compute_scaling_coefficient() has not been implemented for operator {}."
        raise NotImplementedError(msg.format(type(self)))
class RemeshRestrictionFilterBase(RestrictionFilterBase):
    """
    Base implementation for lowpass spatial filtering: fine grid -> coarse grid
    using remeshing kernels.
    """

    __default_method = {
        Remesh: Remesh.L2_1,
    }

    __available_methods = {
        Remesh: (InstanceOf(Remesh), InstanceOf(RemeshKernel)),
    }

    @classmethod
    def default_method(cls):
        """Parent default methods, with Remesh.L2_1 as the remeshing kernel."""
        dm = super().default_method()
        dm.update(cls.__default_method)
        return dm

    @classmethod
    def available_methods(cls):
        """Parent available methods, accepting a Remesh enum or a RemeshKernel."""
        am = super().available_methods()
        am.update(cls.__available_methods)
        return am

    @debug
    def handle_method(self, method):
        """Extract the remeshing kernel, converting a Remesh enum if needed."""
        super().handle_method(method)
        remesh_kernel = method.pop(Remesh)
        if isinstance(remesh_kernel, Remesh):
            remesh_kernel = RemeshKernel.from_enum(remesh_kernel)
        self.remesh_kernel = remesh_kernel

    @classmethod
    def _remesh_ghosts(cls, remesh_kernel):
        """Return the minimum number of ghosts for remeshed scalars."""
        assert remesh_kernel.n >= 1, "Bad remeshing kernel."
        if remesh_kernel.n > 1:
            assert remesh_kernel.n % 2 == 0, "Odd remeshing kernel moments."
        min_ghosts = int(remesh_kernel.n // 2) + 1
        return min_ghosts

    @debug
    def get_field_requirements(self):
        """Require ``iratio * remesh_ghosts - 1`` ghosts on the fine input grid."""
        requirements = super().get_field_requirements()
        iratio = self.iratio
        remesh_ghosts = self._remesh_ghosts(self.remesh_kernel)
        fine_grid_ghosts = tuple(np.multiply(iratio, remesh_ghosts) - 1)
        Fin_topo, Fin_requirements = requirements.get_input_requirement(self.Fin)
        Fin_requirements.min_ghosts = fine_grid_ghosts
        self.remesh_ghosts = remesh_ghosts
        self.fine_grid_ghosts = fine_grid_ghosts
        return requirements

    def compute_weights(self, iratio, product=True):
        """
        Compute normalized remeshing weights for an integer restriction ratio.

        The kernel is sampled at offsets ``X = (idx + 1) / iratio - p`` with
        ``p = n // 2 + 1`` over a stencil of shape ``2 * p * iratio - 1``, then
        normalized so that the weights sum to one.

        Parameters
        ----------
        iratio: array-like of int
            Grid ratio per axis, every entry >= 1.
        product: bool, optional
            If True (default), use the tensor product of the 1D kernel evaluated
            on each axis; otherwise evaluate the kernel on the radial distance.

        Sets ``self.weights`` (dense float64 array) and ``self.nz_weights``
        (dict mapping stencil index -> non-zero normalized weight).
        """
        iratio_np = np.asarray(iratio)
        assert (iratio_np >= 1).all()
        remesh_kernel = self.remesh_kernel
        p = remesh_kernel.n // 2 + 1
        shape = 2 * p * iratio_np - 1
        weights = npw.zeros(dtype=npw.float64, shape=shape)
        nz_weights = {}
        for idx in npw.ndindex(*shape):
            X = (npw.asarray(idx, dtype=npw.float64) + 1) / iratio_np - p
            if product:
                W = npw.prod(remesh_kernel(X))
            else:
                # this does not seem to work because the sum of the weights is ~1e-5
                R = npw.sqrt(npw.dot(X, X))
                W = remesh_kernel(R)
            weights[idx] = W
            if W != 0:
                nz_weights[idx] = W
        Ws = weights.sum()
        weights = weights / Ws
        nz_weights = {k: v / Ws for (k, v) in nz_weights.items()}

        assert abs(weights.sum() - 1.0) < 1e-8, weights.sum()
        # Builtin sum: numpy does not iterate a dict view (it wraps it into a
        # 0-d object array), so numpy's sum would not return a number here.
        nz_total = sum(nz_weights.values())
        assert abs(nz_total - 1.0) < 1e-8, nz_total

        self.weights = weights
        self.nz_weights = nz_weights

    @debug
    def discretize(self):
        """Compute weights and bind the ghosted input view and output buffer."""
        if self.discretized:
            return
        super().discretize()
        dFin, dFout = self.dFin, self.dFout
        grid_ratio = self.grid_ratio
        self.compute_weights(grid_ratio)
        remesh_ghosts = self.remesh_ghosts
        # Recomputed from grid_ratio so the ghosts follow the discrete axis order.
        fine_grid_ghosts = np.multiply(grid_ratio, remesh_ghosts) - 1
        fin = dFin.sdata[dFin.local_slices(ghosts=fine_grid_ghosts)]
        fout = dFout.compute_buffers[0]
        self.fin, self.fout = fin, fout
class SubgridRestrictionFilterBase(RestrictionFilterBase):
    """
    Base implementation for lowpass spatial filtering: fine grid -> coarse grid
    by subsampling, i.e. keeping one fine-grid point every ``grid_ratio``
    points along each axis.
    """

    @debug
    def discretize(self):
        """Bind a strided view of the fine input buffer and the coarse output buffer."""
        if self.discretized:
            return
        super().discretize()
        ratio = self.grid_ratio
        strided = tuple(slice(None, None, step) for step in ratio)
        subsampled = self.dFin.compute_buffers[0][strided]
        target = self.dFout.compute_buffers[0]

        msg = "Something went wrong during slicing: fin.shape={}, fout.shape={}".format(
            subsampled.shape, target.shape
        )
        assert subsampled.shape == target.shape, msg
        assert npw.prod(ratio) == npw.prod(self.iratio), msg

        self.fin, self.fout = subsampled, target
class PolynomialInterpolationFilterBase(
    PolynomialInterpolationMethod, InterpolationFilterBase
):
    """
    Base implementation for polynomial interpolation from a coarse input grid
    to a finer output grid.
    """

    @debug
    def get_field_requirements(self):
        """Require the interpolator ghosts (plus periodicity) on the input field."""
        requirements = super().get_field_requirements()
        # Stencil ghosts with the periodicity flags added elementwise
        # (presumably one extra ghost per periodic axis).
        ghosts = np.add(self.polynomial_interpolator.ghosts, self.Fin.periodicity)
        _, input_requirements = requirements.get_input_requirement(self.Fin)
        input_requirements.min_ghosts = ghosts
        self.required_input_ghosts = ghosts
        return requirements

    def discretize(self):
        """Build the subgrid interpolator and bind input/output data handles."""
        if self.discretized:
            return
        super().discretize()
        dFin, dFout = self.dFin, self.dFout
        ghosts = dFin.topology_state.transposed(self.required_input_ghosts)
        self.subgrid_interpolator = (
            self.polynomial_interpolator.generate_subgrid_interpolator(
                grid_ratio=self.grid_ratio
            )
        )
        self.fin = dFin.sdata[dFin.local_slices(ghosts=ghosts)].handle
        self.fout = dFout.sdata[dFout.compute_slices].handle
        self.iter_shape = dFin.compute_resolution + 1 - dFin.periodicity
class PolynomialRestrictionFilterBase(
    PolynomialInterpolationMethod, RestrictionFilterBase
):
    """
    Base implementation for polynomial restriction from a fine input grid
    to a coarser output grid.
    """

    @debug
    def get_field_requirements(self):
        """Require ``iratio * (stencil_ghosts + 1) - 1`` ghosts on the fine input."""
        requirements = super().get_field_requirements()
        ratio = self.iratio
        stencil_ghosts = self.polynomial_interpolator.ghosts
        ghosts = np.multiply(ratio, np.add(stencil_ghosts, 1)) - 1
        _, input_requirements = requirements.get_input_requirement(self.Fin)
        input_requirements.min_ghosts = ghosts
        self.required_input_ghosts = ghosts
        return requirements

    def discretize(self):
        """Build the subgrid restrictor and bind input/output data handles."""
        if self.discretized:
            return
        super().discretize()
        dFin, dFout = self.dFin, self.dFout
        ghosts = dFin.topology_state.transposed(self.required_input_ghosts)
        interpolator = self.polynomial_interpolator.generate_subgrid_interpolator(
            grid_ratio=self.grid_ratio
        )
        restrictor = interpolator.generate_subgrid_restrictor()
        # The restrictor stencil must match the ghosts requested above.
        assert all(restrictor.ghosts == ghosts)
        self.subgrid_restrictor = restrictor
        self.fin = dFin.sdata[dFin.local_slices(ghosts=ghosts)].handle
        self.fout = dFout.sdata[dFout.compute_slices].handle
        self.iter_shape = dFout.compute_resolution